import paddle
import numpy as np

class SSLoss():
    """
    SimOTA-style (as in YOLOX) label assignment and loss for a multi-scale,
    anchor-free detection head.

    Every prediction map is sliced as ``[:, :4]`` (box offsets tx, ty, tw, th),
    ``[:, 4]`` (objectness logit) and ``[:, 5:]`` (class logits). Ground-truth
    boxes are treated as normalized (cx, cy, w, h) rows. Rows whose
    coordinates do not sum to a positive value are treated as padding.
    """

    def __init__(self, num_classes=20, out_strides=(32, 16, 8)):
        """
        Initialize the loss function.

        * Args:
        - num_classes: number of object classes
        - out_strides: downsampling stride of each output map, in the same
          order as the prediction maps passed to ``__call__``
        """
        super().__init__()

        self.num_classes = num_classes
        # Keep a private copy. The default used to be a shared mutable list,
        # and a caller-owned list should not alias this attribute either.
        self.out_strides = list(out_strides)

    def __call__(self, p_list, inputs):
        """
        Compute the training loss.

        * Args:
        - p_list: prediction maps, one per stride in ``out_strides``
        - inputs: batch dict read through the keys 'image', 'gtbox', 'gtcls'
        * Returns:
        - losses: loss summed over all output scales
        """
        # Assign targets first (no gradient), then score predictions against them.
        labels = self.get_labels(p_list, inputs)
        losses = self.get_losses(p_list, labels)

        return losses

    #####################################################################################
    # Label assignment
    #####################################################################################

    @paddle.no_grad()
    def get_labels(self, p_list, inputs):
        """
        Assign training targets to every grid cell of every output map.

        * Args:
        - p_list: prediction maps, each [b, 5 + c, h, w]
        - inputs: batch dict; 'image' is [b, ., H, W] and 'gtbox' / 'gtcls'
          hold the per-image ground truth
        * Returns:
        - labels: one dict per output map with keys
          'lbbox' [b, 4, h, w], 'lbobj' [b, h, w] and 'lbcls' [b, c, h, w]
        """
        images = inputs['image']
        imgs_b = images.shape[0]  # batch size
        imgs_h = images.shape[2]  # image height in pixels
        imgs_w = images.shape[3]  # image width in pixels
        gt_box = inputs['gtbox']  # ground-truth boxes
        gt_cls = inputs['gtcls']  # ground-truth class ids

        labels = []
        for infer, stride in zip(p_list, self.out_strides):
            pd_loc = infer[:, :4, :, :]   # box offsets, [b, 4, h, w]
            pd_obj = infer[:, 4:5, :, :]  # objectness logits, [b, 1, h, w]
            pd_cls = infer[:, 5:, :, :]   # class logits, [b, c, h, w]

            lbbox_list = []
            lbobj_list = []
            lbcls_list = []
            for b in range(imgs_b):
                if self._count_gt(gt_box[b]) == 0:
                    # No objects in this image, so every cell is a negative sample.
                    # Handle it explicitly so the assignment ops never see
                    # zero-size tensors.
                    lbbox, lbobj, lbcls = self._get_empty_label(stride, imgs_w, imgs_h)
                else:
                    lbbox, lbobj, lbcls = self._assign_image(
                        gt_box[b], gt_cls[b], pd_loc[b], pd_obj[b], pd_cls[b],
                        stride, imgs_w, imgs_h)

                lbbox_list.append(lbbox)
                lbobj_list.append(lbobj)
                lbcls_list.append(lbcls)

            labels.append({
                'lbbox': paddle.stack(lbbox_list, axis=0),  # [b, 4, h, w]
                'lbobj': paddle.stack(lbobj_list, axis=0),  # [b, h, w]
                'lbcls': paddle.stack(lbcls_list, axis=0),  # [b, c, h, w]
            })

        return labels

    def _assign_image(self, gt_box, gt_cls, pd_loc, pd_obj, pd_cls, stride, imgs_w, imgs_h):
        """Run the cost-based assignment for one image at one stride; return (lbbox, lbobj, lbcls)."""
        omask, nmask = self.get_masks(gt_box, stride, imgs_w, imgs_h)
        gtbox, pdbox = self.get_gpbox(gt_box, pd_loc, omask, imgs_w, imgs_h)
        gtobj, pdobj = self.get_gpobj(gt_box, pd_obj, omask)
        gtcls, pdcls = self.get_gpcls(gt_box, gt_cls, pd_cls, omask)
        gpiou = self.get_gpiou(gtbox, pdbox)
        costs = self.get_costs(gpiou, gtobj, pdobj, gtcls, pdcls, nmask)
        cmask = self.get_cmask(gpiou, costs)
        return self.get_label(gtbox, gtcls, cmask)

    def _get_empty_label(self, stride, imgs_w, imgs_h):
        """Return all-zero (lbbox, lbobj, lbcls) targets for an image with no objects."""
        # Use the same grid size as get_masks so these labels stack with the
        # labels of other images in the batch.
        grid_h = int(imgs_h / stride)
        grid_w = int(imgs_w / stride)
        lbbox = paddle.zeros([4, grid_h, grid_w], dtype='float32')
        lbobj = paddle.zeros([grid_h, grid_w], dtype='float32')
        lbcls = paddle.zeros([self.num_classes, grid_h, grid_w], dtype='float32')
        return lbbox, lbobj, lbcls

    @staticmethod
    def _count_gt(gt_box):
        """
        Count ground-truth rows whose coordinates sum to a positive value.

        Callers keep ``gt_box[:gtnum]``, so valid boxes are assumed to come
        before the zero padding rows.
        """
        return int((gt_box.numpy().sum(axis=-1) > 0).sum())

    @staticmethod
    def _inside(grid_x, grid_y, x1, y1, x2, y2):
        """
        Return a bool [n, h, w] mask of cells lying strictly inside each rectangle.

        ``grid_x``/``grid_y`` are [1, h, w] cell coordinates; ``x1..y2`` are [n] edges.
        """
        x1 = x1.unsqueeze([1, 2])  # [n, 1, 1]
        y1 = y1.unsqueeze([1, 2])
        x2 = x2.unsqueeze([1, 2])
        y2 = y2.unsqueeze([1, 2])
        dist = paddle.stack(
            [grid_x - x1, grid_y - y1, x2 - grid_x, y2 - grid_y], axis=-1)  # [n, h, w, 4]
        return dist.min(axis=-1) > 0

    def get_masks(self, gt_box, stride, imgs_w, imgs_h):
        """
        Compute the candidate-region masks for one image.

        Each cell is tested at its center point. A cell is in the "center"
        region of a box when it lies within radius * stride pixels of the box
        center on both axes. It is in the "border" region when it lies inside
        the box itself.

        * Args:
        - gt_box: ground-truth boxes of one image, normalized (cx, cy, w, h)
        - stride: downsampling stride of this output map
        - imgs_w: image width in pixels
        - imgs_h: image height in pixels
        * Returns:
        - omask: float [h, w]; 1 where a cell is in the center OR border region of any box
        - nmask: float [n, h, w]; 1 where a cell is NOT in both regions of box n
        """
        gtnum = self._count_gt(gt_box)
        gtbox = gt_box[:gtnum]  # [n, 4]

        grid_h = int(imgs_h / stride)
        grid_w = int(imgs_w / stride)

        yv, xv = paddle.meshgrid([paddle.arange(grid_h), paddle.arange(grid_w)])
        grid_o = paddle.stack([xv, yv], axis=-1).astype('float32')  # cell origins, [h, w, 2]
        grid_c = grid_o + 0.5                                       # cell centers, [h, w, 2]
        grid_x = grid_c[:, :, 0].unsqueeze(0) / grid_w              # normalized x, [1, h, w]
        grid_y = grid_c[:, :, 1].unsqueeze(0) / grid_h              # normalized y, [1, h, w]

        # Center region: a rectangle of half-size radius * stride pixels
        # around each box center, converted to normalized units.
        radius = 2.5
        half_w = radius * stride / imgs_w
        half_h = radius * stride / imgs_h
        center_mask = self._inside(
            grid_x, grid_y,
            gtbox[:, 0] - half_w, gtbox[:, 1] - half_h,
            gtbox[:, 0] + half_w, gtbox[:, 1] + half_h)  # [n, h, w]

        # Border region: the box itself (center +/- half width/height).
        scales = 0.5
        border_mask = self._inside(
            grid_x, grid_y,
            gtbox[:, 0] - scales * gtbox[:, 2], gtbox[:, 1] - scales * gtbox[:, 3],
            gtbox[:, 0] + scales * gtbox[:, 2], gtbox[:, 1] + scales * gtbox[:, 3])  # [n, h, w]

        center_smsk = center_mask.sum(axis=0) > 0  # in the center region of any box, [h, w]
        border_smsk = border_mask.sum(axis=0) > 0  # in the border region of any box, [h, w]

        omask = paddle.logical_or(center_smsk, border_smsk)                       # [h, w]
        nmask = paddle.logical_not(paddle.logical_and(center_mask, border_mask))  # [n, h, w]

        return omask.astype('float32'), nmask.astype('float32')

    def get_gpbox(self, gt_box, pd_loc, omask, imgs_w, imgs_h):
        """
        Build paired ground-truth and decoded predicted boxes for every cell.

        * Args:
        - gt_box: ground-truth boxes of one image, normalized (cx, cy, w, h)
        - pd_loc: raw box offsets of one image, [4, h, w]
        - omask : float candidate mask, [h, w]
        - imgs_w: image width in pixels
        - imgs_h: image height in pixels
        * Returns:
        - gtbox : ground-truth boxes broadcast to cells, zero outside omask, [n, 4, h, w]
        - pdbox : decoded predicted boxes, zero outside omask, [n, 4, h, w]
        """
        gtnum = self._count_gt(gt_box)
        omask = omask.unsqueeze([0, 1])  # [1, 1, h, w]

        grid_h = omask.shape[-2]
        grid_w = omask.shape[-1]

        yv, xv = paddle.meshgrid([paddle.arange(grid_h), paddle.arange(grid_w)])
        grid_o = paddle.stack([xv, yv], axis=-1).astype('float32')  # cell origins, [h, w, 2]
        grid_x = grid_o[:, :, 0]                                    # [h, w]
        grid_y = grid_o[:, :, 1]                                    # [h, w]

        gtbox = gt_box[:gtnum]                                        # [n, 4]
        gtbox = gtbox.unsqueeze([2, 3]).tile([1, 1, grid_h, grid_w])  # [n, 4, h, w]
        gtbox = gtbox * omask                                         # [n, 4, h, w]

        pd_tx = pd_loc[0, :, :]  # [h, w]
        pd_ty = pd_loc[1, :, :]
        pd_tw = pd_loc[2, :, :]
        pd_th = pd_loc[3, :, :]

        # Centers: cell origin + sigmoid offset, normalized by grid size.
        # Sizes: exp(t) in pixels, normalized by image size.
        pd_px = (grid_x + paddle.nn.functional.sigmoid(pd_tx)) / grid_w
        pd_py = (grid_y + paddle.nn.functional.sigmoid(pd_ty)) / grid_h
        pd_pw = paddle.exp(pd_tw) / imgs_w
        pd_ph = paddle.exp(pd_th) / imgs_h

        pdbox = paddle.stack([pd_px, pd_py, pd_pw, pd_ph], axis=0)  # [4, h, w]
        pdbox = pdbox.unsqueeze(0).tile([gtnum, 1, 1, 1])           # [n, 4, h, w]
        pdbox = pdbox * omask                                       # [n, 4, h, w]

        return gtbox, pdbox

    def get_gpobj(self, gt_box, pd_obj, omask):
        """
        Build paired objectness targets and predictions.

        * Args:
        - gt_box: ground-truth boxes of one image (used only to count boxes)
        - pd_obj: objectness logits of one image, [1, h, w]
        - omask : float candidate mask, [h, w]
        * Returns:
        - gtobj : omask repeated per box, [n, h, w]
        - pdobj : sigmoid(pd_obj) repeated per box, zero outside omask, [n, h, w]
        """
        gtnum = self._count_gt(gt_box)
        omask = omask.unsqueeze(0)  # [1, h, w]

        gtobj = omask.tile([gtnum, 1, 1])  # [n, h, w]

        pdobj = paddle.nn.functional.sigmoid(pd_obj)  # [1, h, w]
        pdobj = pdobj.tile([gtnum, 1, 1])             # [n, h, w]
        pdobj = pdobj * omask                         # [n, h, w]

        return gtobj, pdobj

    def get_gpcls(self, gt_box, gt_cls, pd_cls, omask):
        """
        Build paired one-hot class targets and class probabilities.

        * Args:
        - gt_box: ground-truth boxes of one image (used only to count boxes)
        - gt_cls: ground-truth class ids of one image
        - pd_cls: class logits of one image, [c, h, w]
        - omask : float candidate mask, [h, w]
        * Returns:
        - gtcls : one-hot classes broadcast to cells, zero outside omask, [n, c, h, w]
        - pdcls : sigmoid(pd_cls) repeated per box, zero outside omask, [n, c, h, w]
        """
        gtnum = self._count_gt(gt_box)
        omask = omask.unsqueeze([0, 1])  # [1, 1, h, w]

        grid_h = omask.shape[-2]
        grid_w = omask.shape[-1]

        # NOTE(review): the [n, c] result assumes gt_cls is a 1-D integer vector
        # of class ids; a [n, 1] input would give one_hot shape [n, 1, c]. Confirm.
        gtcls = gt_cls[:gtnum]
        gtcls = paddle.nn.functional.one_hot(gtcls, self.num_classes)  # [n, c]
        gtcls = gtcls.unsqueeze([2, 3]).tile([1, 1, grid_h, grid_w])   # [n, c, h, w]
        gtcls = gtcls * omask                                          # [n, c, h, w]

        pdcls = paddle.nn.functional.sigmoid(pd_cls)       # [c, h, w]
        pdcls = pdcls.unsqueeze(0).tile([gtnum, 1, 1, 1])  # [n, c, h, w]
        pdcls = pdcls * omask                              # [n, c, h, w]

        return gtcls, pdcls

    def get_gpiou(self, gtbox, pdbox):
        """
        Compute element-wise IoU between two sets of (cx, cy, w, h) boxes.

        The box coordinates live on axis 1, so any [k, 4, h, w] layout works.

        * Args:
        - gtbox: first box set, [k, 4, h, w]
        - pdbox: second box set, [k, 4, h, w]
        * Returns:
        - gpiou: IoU per position, [k, h, w]
        """
        # Convert center format to corner format.
        gt_x1 = gtbox[:, 0, :, :] - 0.5 * gtbox[:, 2, :, :]
        gt_y1 = gtbox[:, 1, :, :] - 0.5 * gtbox[:, 3, :, :]
        gt_x2 = gtbox[:, 0, :, :] + 0.5 * gtbox[:, 2, :, :]
        gt_y2 = gtbox[:, 1, :, :] + 0.5 * gtbox[:, 3, :, :]

        pd_x1 = pdbox[:, 0, :, :] - 0.5 * pdbox[:, 2, :, :]
        pd_y1 = pdbox[:, 1, :, :] - 0.5 * pdbox[:, 3, :, :]
        pd_x2 = pdbox[:, 0, :, :] + 0.5 * pdbox[:, 2, :, :]
        pd_y2 = pdbox[:, 1, :, :] + 0.5 * pdbox[:, 3, :, :]

        # Intersection rectangle.
        x1 = paddle.maximum(gt_x1, pd_x1)
        y1 = paddle.maximum(gt_y1, pd_y1)
        x2 = paddle.minimum(gt_x2, pd_x2)
        y2 = paddle.minimum(gt_y2, pd_y2)

        intersection = (x2 - x1).clip(0) * (y2 - y1).clip(0)

        s1 = (gt_x2 - gt_x1).clip(0) * (gt_y2 - gt_y1).clip(0)
        s2 = (pd_x2 - pd_x1).clip(0) * (pd_y2 - pd_y1).clip(0)

        # The epsilon keeps fully-masked (all-zero) positions at IoU 0 instead of NaN.
        union = s1 + s2 - intersection + 1e-9

        gpiou = intersection / union

        return gpiou

    def get_costs(self, gpiou, gtobj, pdobj, gtcls, pdcls, nmask):
        """
        Compute the assignment cost between every box and every cell.

        * Args:
        - gpiou: IoU, [n, h, w]
        - gtobj: objectness targets, [n, h, w]
        - pdobj: objectness probabilities, [n, h, w]
        - gtcls: one-hot class targets, [n, c, h, w]
        - pdcls: class probabilities, [n, c, h, w]
        - nmask: 1 where a cell is not in both center and border regions, [n, h, w]
        * Returns:
        - costs: assignment cost, [n, h, w]
        """
        loss_wgt = 5.0
        loss_iou = loss_wgt * -paddle.log(gpiou + 1e-9)  # [n, h, w]

        loss_obj = paddle.nn.functional.binary_cross_entropy(
            pdobj, gtobj, reduction="none")  # [n, h, w]

        loss_cls = paddle.nn.functional.binary_cross_entropy(
            pdcls, gtcls, reduction="none")  # [n, c, h, w]
        loss_cls = loss_cls.sum(axis=1)      # [n, h, w]

        # Large penalty so cells outside the center-AND-border intersection are
        # ranked after all cells inside it.
        loss_wgt = 100000.0
        loss_neg = loss_wgt * nmask  # [n, h, w]

        costs = loss_iou + loss_obj + loss_cls + loss_neg

        return costs

    def get_cmask(self, gpiou, costs):
        """
        Select positive cells for each ground-truth box using a dynamic top-k.

        For box j, k_j = max(1, int(sum of its top-10 IoUs)), and the k_j
        lowest-cost cells are selected. If several boxes select the same cell,
        only the box with the lowest cost at that cell keeps it.

        * Args:
        - gpiou: IoU, [n, h, w]
        - costs: assignment cost, [n, h, w]
        * Returns:
        - cmask: float, 1 where a cell is assigned to box n, [n, h, w]
        """
        grid_h = gpiou.shape[-2]
        grid_w = gpiou.shape[-1]
        num_cells = grid_h * grid_w

        # Dynamic k per box. Cap the candidate count at the number of cells:
        # paddle.topk rejects k larger than the axis length, which happens for
        # small inputs at coarse strides (e.g. a 3x3 grid).
        topk_iou = gpiou.reshape([-1, num_cells])  # [n, h*w]
        topk_box, _ = paddle.topk(topk_iou, min(10, num_cells), axis=1)
        topk_num = topk_box.sum(axis=1).clip(min=1).astype('int64').tolist()

        costs = costs.reshape([-1, num_cells])  # [n, h*w]
        cmask = paddle.zeros_like(costs)        # [n, h*w]
        for j in range(costs.shape[0]):
            _, k = paddle.topk(costs[j], topk_num[j], largest=False)  # lowest costs first
            cmask[j, k] = 1.0

        # Resolve cells claimed by more than one box: keep only the cheapest box.
        costs = costs.numpy()
        cmask = cmask.numpy()
        smask = cmask.sum(axis=0) > 1.0  # shared cells, [h*w]
        if smask.sum() > 0:
            j = np.argmin(costs[:, smask], axis=0)
            cmask[:, smask] *= 0.0
            cmask[j, smask] = 1.0

        cmask = paddle.to_tensor(cmask)
        return cmask.reshape([-1, grid_h, grid_w])  # [n, h, w]

    def get_label(self, gtbox, gtcls, cmask):
        """
        Collapse per-box targets into per-cell labels.

        Each cell is assigned to at most one box, so summing over boxes
        selects that box's target.

        * Args:
        - gtbox: ground-truth boxes per cell, [n, 4, h, w]
        - gtcls: one-hot classes per cell, [n, c, h, w]
        - cmask: assignment mask, [n, h, w]
        * Returns:
        - lbbox: box label, [4, h, w]
        - lbobj: objectness label, [h, w]
        - lbcls: class label, [c, h, w]
        """
        lbbox = (gtbox * cmask.unsqueeze(1)).sum(axis=0)  # [4, h, w]
        lbobj = cmask.sum(axis=0)                         # [h, w]
        lbcls = (gtcls * cmask.unsqueeze(1)).sum(axis=0)  # [c, h, w]

        return lbbox, lbobj, lbcls

    #####################################################################################
    # Loss computation
    #####################################################################################

    def get_losses(self, p_list, labels):
        """
        Compute the detection loss from predictions and assigned labels.

        For each scale the loss combines three terms:
        - 5 * (1 - IoU) on positive cells;
        - objectness BCE on all cells;
        - class BCE (summed over classes) on positive cells.

        The per-image sums are averaged over the batch, and those averages are
        summed over scales.

        * Args:
        - p_list: prediction maps, each [b, 5 + c, h, w]
        - labels: output of ``get_labels``
        * Returns:
        - losses: total loss
        """
        losses = 0
        for infer, label, stride in zip(p_list, labels, self.out_strides):
            pdloc = infer[:, :4, :, :]            # box offsets, [b, 4, h, w]
            lbbox = label['lbbox']                # box labels, [b, 4, h, w]
            pmask = label['lbobj']                # positive-cell mask, [b, h, w]
            pdbox = self.get_pdbox(pdloc, stride) # decoded boxes, [b, 4, h, w]

            loss_wgt = 5.0
            loss_iou = 1 - self.get_gpiou(lbbox, pdbox)  # [b, h, w]
            loss_iou = loss_wgt * loss_iou * pmask       # positives only

            # Logits: the sigmoid is applied inside binary_cross_entropy_with_logits.
            pdobj = infer[:, 4, :, :]  # objectness logits, [b, h, w]
            lbobj = label['lbobj']     # objectness labels (0/1), [b, h, w]

            loss_obj = paddle.nn.functional.binary_cross_entropy_with_logits(
                pdobj, lbobj, reduction='none')  # [b, h, w]

            pdcls = infer[:, 5:, :, :]  # class logits, [b, c, h, w]
            lbcls = label['lbcls']      # class labels, [b, c, h, w]

            loss_cls = paddle.nn.functional.binary_cross_entropy_with_logits(
                pdcls, lbcls, reduction='none')      # [b, c, h, w]
            loss_cls = loss_cls.sum(axis=1) * pmask  # positives only, [b, h, w]

            loss_sum = loss_iou + loss_obj + loss_cls    # [b, h, w]
            loss_avg = loss_sum.sum(axis=[1, 2]).mean()  # per-image sum, batch mean
            losses += loss_avg

        return losses

    def get_pdbox(self, pdloc, stride):
        """
        Decode raw box offsets into normalized (cx, cy, w, h) boxes.

        The image size is inferred as grid size * stride.

        * Args:
        - pdloc : raw box offsets, [b, 4, h, w]
        - stride: downsampling stride of this output map
        * Returns:
        - pdbox: decoded boxes, [b, 4, h, w]
        """
        grid_h = pdloc.shape[-2]
        grid_w = pdloc.shape[-1]
        imgs_h = grid_h * stride
        imgs_w = grid_w * stride

        yv, xv = paddle.meshgrid([paddle.arange(grid_h), paddle.arange(grid_w)])
        grid_o = paddle.stack([xv, yv], axis=-1).astype('float32')  # cell origins, [h, w, 2]
        grid_x = grid_o[:, :, 0].unsqueeze(0)                       # [1, h, w]
        grid_y = grid_o[:, :, 1].unsqueeze(0)                       # [1, h, w]

        pd_tx = pdloc[:, 0, :, :]  # [b, h, w]
        pd_ty = pdloc[:, 1, :, :]
        pd_tw = pdloc[:, 2, :, :]
        pd_th = pdloc[:, 3, :, :]

        # Centers: cell origin + sigmoid offset, normalized by grid size.
        # Sizes: exp(t) in pixels, normalized by image size.
        pd_px = (grid_x + paddle.nn.functional.sigmoid(pd_tx)) / grid_w
        pd_py = (grid_y + paddle.nn.functional.sigmoid(pd_ty)) / grid_h
        pd_pw = paddle.exp(pd_tw) / imgs_w
        pd_ph = paddle.exp(pd_th) / imgs_h

        pdbox = paddle.stack([pd_px, pd_py, pd_pw, pd_ph], axis=1)  # [b, 4, h, w]

        return pdbox